import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import sys
import os

# from .resnet import mspanet50
from .swin_transformer import SwinTransformer
from .newcrf_layers import NewCRF
from .uper_crf_head import PSP
# from .FFTNet import FFTBlock,CBAM
########################################################################################################################


class NewCRFDepth(nn.Module):
    """
    Depth network based on neural window FC-CRFs architecture.

    A ``SwinTransformer`` backbone yields four feature maps, a ``PSP`` decoder
    is applied to them, and four ``NewCRF`` stages (``crf3`` .. ``crf0``) refine
    the result from the coarsest to the finest scale. ``DispHead`` maps the
    finest stage to a sigmoid output which is multiplied by ``max_depth``.

    Args:
        version (str): Backbone variant followed by a two-character window
            size, e.g. ``'large07'``. Supported variants are ``'tiny'``,
            ``'base'`` and ``'large'``.
        inv_depth (bool): Stored on the instance; not read anywhere else in
            this class.
        pretrained (str, optional): Passed to ``init_weights``.
        frozen_stages (int): Forwarded to the backbone configuration.
        min_depth (float): Stored on the instance; ``forward`` does not use it.
        max_depth (float): Factor applied to the sigmoid output in ``forward``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``version`` is missing, malformed, or names an unknown
            variant.
    """

    # Backbone hyper-parameters per variant. ``in_channels`` are the channel
    # counts of the four backbone outputs consumed by the decoder and the CRFs.
    # The original author marked 'large' as the variant in use.
    _BACKBONE_VARIANTS = {
        'base': dict(embed_dim=128, depths=[2, 2, 18, 2],
                     num_heads=[4, 8, 16, 32],
                     in_channels=[128, 256, 512, 1024]),
        'large': dict(embed_dim=192, depths=[2, 2, 18, 2],
                      num_heads=[6, 12, 24, 48],
                      in_channels=[192, 384, 768, 1536]),
        'tiny': dict(embed_dim=96, depths=[2, 2, 6, 2],
                     num_heads=[3, 6, 12, 24],
                     in_channels=[96, 192, 384, 768]),
    }

    def __init__(self, version=None, inv_depth=False, pretrained=None,
                 frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs):
        super().__init__()

        # Validate before building anything so a bad ``version`` fails fast
        # with a readable message instead of a TypeError / UnboundLocalError.
        variant, window_size = self._parse_version(version)
        variant_cfg = self._BACKBONE_VARIANTS[variant]
        # Copy the lists so the shared class-level table can never be mutated
        # by the modules that receive them.
        in_channels = list(variant_cfg['in_channels'])

        self.inv_depth = inv_depth
        # No auxiliary head or neck is created by this class; the flags keep
        # the corresponding branches in init_weights/forward disabled.
        self.with_auxiliary_head = False
        self.with_neck = False

        norm_cfg = dict(type='BN', requires_grad=True)

        backbone_cfg = dict(
            embed_dim=variant_cfg['embed_dim'],
            depths=list(variant_cfg['depths']),
            num_heads=list(variant_cfg['num_heads']),
            window_size=window_size,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
            frozen_stages=frozen_stages
        )

        # Width of the PSP decoder output; it is also the value dimension fed
        # to the coarsest CRF stage (last entry of ``v_dims``).
        decoder_channels = 512
        decoder_cfg = dict(
            in_channels=in_channels,
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),
            channels=decoder_channels,
            dropout_ratio=0.0,
            num_classes=32,
            norm_cfg=norm_cfg,
            align_corners=False
        )

        # NOTE: earlier experiments (a ResNet pre-extractor, FFT blocks and
        # CBAM attention on each scale) existed here only as commented-out
        # code and were removed.
        self.backbone = SwinTransformer(**backbone_cfg)

        # The CRF window is fixed at 7 independently of the backbone window.
        win = 7
        crf_dims = [128, 256, 512, 1024]
        v_dims = [64, 128, 256, decoder_channels]  # [64, 128, 256, 512]

        # Each v_dim is a quarter of the next-coarser crf_dim, matching the
        # 4x channel reduction of the 2x pixel shuffle applied in forward().
        self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3],
                           window_size=win, v_dim=v_dims[3], num_heads=32)
        self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2],
                           window_size=win, v_dim=v_dims[2], num_heads=16)
        self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1],
                           window_size=win, v_dim=v_dims[1], num_heads=8)
        self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0],
                           window_size=win, v_dim=v_dims[0], num_heads=4)

        self.decoder = PSP(**decoder_cfg)
        self.disp_head1 = DispHead(input_dim=crf_dims[0])

        # 'bilinear' upsamples the head output by interpolation; 'mask' would
        # instead learn a convex-combination mask (see upsample_mask).
        self.up_mode = 'bilinear'
        if self.up_mode == 'mask':
            self.mask_head = nn.Sequential(
                nn.Conv2d(crf_dims[0], 64, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 16*9, 1, padding=0))

        self.min_depth = min_depth
        self.max_depth = max_depth

        self.init_weights(pretrained=pretrained)

    @classmethod
    def _parse_version(cls, version):
        """Split ``version`` into ``(variant, window_size)`` and validate it.

        The last two characters are the window size and everything before
        them is the variant name, e.g. ``'large07' -> ('large', 7)``.

        Raises:
            ValueError: If ``version`` is not a string of that form or the
                variant is not in ``_BACKBONE_VARIANTS``.
        """
        supported = sorted(cls._BACKBONE_VARIANTS)
        if not isinstance(version, str) or len(version) < 3:
            raise ValueError(
                "version must be a string like 'large07' "
                f"(variant in {supported} + two-digit window size), "
                f"got {version!r}")

        variant, window = version[:-2], version[-2:]
        if variant not in cls._BACKBONE_VARIANTS:
            raise ValueError(
                f"unknown backbone variant {variant!r} in version "
                f"{version!r}; expected one of {supported}")
        try:
            window_size = int(window)
        except ValueError as err:
            raise ValueError(
                f"version {version!r} must end with a two-digit window size"
            ) from err
        return variant, window_size

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone and heads.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        print(f'== Load encoder backbone from: {pretrained}')
        self.backbone.init_weights(pretrained=pretrained)
        self.decoder.init_weights()
        if self.with_auxiliary_head:
            if isinstance(self.auxiliary_head, nn.ModuleList):
                for aux_head in self.auxiliary_head:
                    aux_head.init_weights()
            else:
                self.auxiliary_head.init_weights()

    def upsample_mask(self, disp, mask):
        """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination

        Args:
            disp (Tensor): Low-resolution map of shape ``(N, 1, H, W)``.
            mask (Tensor): Logits reshaped here to ``(N, 1, 9, 4, 4, H, W)``,
                i.e. ``9 * 16`` channels: one weight per 3x3 neighbour for
                each of the 4x4 sub-pixels of every low-resolution pixel.

        Returns:
            Tensor: Upsampled map of shape ``(N, 1, 4*H, 4*W)``.
        """
        N, _, H, W = disp.shape
        mask = mask.view(N, 1, 9, 4, 4, H, W)
        # Softmax over the 9 neighbours makes the weights a convex combination.
        mask = torch.softmax(mask, dim=2)

        # Gather the zero-padded 3x3 neighbourhood of every pixel.
        up_disp = F.unfold(disp, kernel_size=3, padding=1)
        up_disp = up_disp.view(N, 1, 9, 1, 1, H, W)

        up_disp = torch.sum(mask * up_disp, dim=2)
        # (N, 1, 4, 4, H, W) -> (N, 1, H, 4, W, 4) so the sub-pixel axes
        # interleave with the spatial axes in the reshape below.
        up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
        return up_disp.reshape(N, 1, 4*H, 4*W)

    def forward(self, imgs):
        """Predict a depth map for a batch of images.

        The shapes quoted in the comments are the ones recorded by the
        original author for a ``[4, 3, 480, 640]`` input with the 'large'
        variant.

        Args:
            imgs (Tensor): Input batch passed to the backbone.

        Returns:
            Tensor: Sigmoid output of ``disp_head1`` multiplied by
            ``max_depth`` (``[4, 1, 480, 640]`` for the input above).
        """
        # feats: [4,192,120,160], [4,384,60,80], [4,768,30,40], [4,1536,15,20]
        feats = self.backbone(imgs)

        if self.with_neck:
            feats = self.neck(feats)

        ppm_out = self.decoder(feats)  # [4,512,15,20]

        # Coarse-to-fine refinement. Each 2x pixel shuffle doubles the spatial
        # size and divides the channels by 4 before the next CRF stage.
        e3 = self.crf3(feats[3], ppm_out)  # [4,1024,15,20]
        e3 = F.pixel_shuffle(e3, 2)        # [4,256,30,40]
        e2 = self.crf2(feats[2], e3)       # [4,512,30,40]
        e2 = F.pixel_shuffle(e2, 2)        # [4,128,60,80]
        e1 = self.crf1(feats[1], e2)       # [4,256,60,80]
        e1 = F.pixel_shuffle(e1, 2)        # [4,64,120,160]
        e0 = self.crf0(feats[0], e1)       # [4,128,120,160]

        if self.up_mode == 'mask':
            mask = self.mask_head(e0)
            d1 = self.disp_head1(e0, 1)
            d1 = self.upsample_mask(d1, mask)
        else:
            d1 = self.disp_head1(e0, 4)
        # d1: [4,1,480,640]
        depth = d1 * self.max_depth

        return depth

class DispHead(nn.Module):
    """Single-convolution prediction head with a sigmoid output.

    Args:
        input_dim (int): Number of channels of the incoming feature map.
    """

    def __init__(self, input_dim=100):
        super().__init__()
        # 3x3 convolution collapsing the features to one channel; padding=1
        # keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, scale):
        """Return the sigmoid map, upsampled when ``scale`` exceeds 1.

        Args:
            x (Tensor): Feature map with ``input_dim`` channels.
            scale: Upsampling factor handed to ``upsample``; values of 1 or
                less leave the resolution untouched.

        Returns:
            Tensor: One-channel map with values produced by a sigmoid.
        """
        prob = self.sigmoid(self.conv1(x))
        return upsample(prob, scale_factor=scale) if scale > 1 else prob



class DispUnpack(nn.Module):
    """Prediction head that upsamples 4x by unpacking 16 channels.

    Two 3x3 convolutions produce 16 sigmoid-activated channels, which a
    ``PixelShuffle(4)`` rearranges into one channel at four times the
    spatial resolution.

    Args:
        input_dim (int): Channels of the incoming feature map.
        hidden_dim (int): Channels between the two convolutions.
    """

    def __init__(self, input_dim=100, hidden_dim=128):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        # 16 output channels = 4 * 4 sub-pixels per input pixel.
        self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.pixel_shuffle = nn.PixelShuffle(4)

    def forward(self, x, output_size):
        """Return the unpacked one-channel map.

        Args:
            x (Tensor): Feature map with ``input_dim`` channels.
            output_size: Unused; kept so existing callers keep working.

        Returns:
            Tensor: ``(b, 1, 4*h, 4*w)`` map for a ``(b, input_dim, h, w)``
            input.
        """
        hidden = self.relu(self.conv1(x))
        packed = self.sigmoid(self.conv2(hidden))  # (b, 16, h, w)
        return self.pixel_shuffle(packed)


def upsample(x, scale_factor=2, mode="bilinear", align_corners=False):
    """Upsample a tensor by ``scale_factor`` using ``F.interpolate``.

    Args:
        x (Tensor): Input tensor.
        scale_factor: Spatial multiplier (defaults to 2).
        mode (str): Interpolation mode forwarded to ``F.interpolate``.
        align_corners (bool): Forwarded only for the interpolating modes
            (linear, bilinear, bicubic, trilinear); ignored otherwise.

    Returns:
        Tensor: The resized tensor.
    """
    # F.interpolate raises ValueError when align_corners is given for modes
    # such as 'nearest' or 'area', so the default of False would make those
    # modes unusable through this wrapper. Drop it for them.
    if mode not in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = None
    return F.interpolate(x, scale_factor=scale_factor, mode=mode,
                         align_corners=align_corners)

def to_3d(x):
    """Flatten the spatial axes of a ``b c h w`` tensor into ``b (h w) c``."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)



if __name__ == '__main__':
    # Smoke test: push one random batch through the network and print the
    # shape of the predicted depth map.
    # NOTE: this module uses relative imports, so it has to be launched as
    # part of its package (``python -m <package>.<module>``).
    x = torch.randn(4, 3, 480, 640)
    # A version string is mandatory; 'large07' selects the 'large' backbone
    # with window size 7.
    model = NewCRFDepth(version='large07')
    model.eval()
    with torch.no_grad():
        # forward() returns a single depth tensor, not a tuple of features.
        depth = model(x)
    print(depth.shape)